[mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion#93240
Merged
[mlir][spirv] Add vector.interleave to spirv.VectorShuffle conversion#93240
Conversation
Member
|
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Angel Zhang (angelz913) ChangesAdd Full diff: https://github.com/llvm/llvm-project/pull/93240.diff 1 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..b86ebe1a4bb54 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -578,6 +578,44 @@ struct VectorShuffleOpConvert final
}
};
+struct VectorInterleaveOpConvert final
+ : public OpConversionPattern<vector::InterleaveOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Check the source vector type
+ auto sourceType = interleaveOp.getSourceVectorType();
+ if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+ return rewriter.notifyMatchFailure(interleaveOp,
+ "unsupported source vector type");
+ }
+
+ // Check the result vector type
+ auto oldResultType = interleaveOp.getResultVectorType();
+ Type newResultType = getTypeConverter()->convertType(oldResultType);
+ if (!newResultType)
+ return rewriter.notifyMatchFailure(interleaveOp,
+ "unsupported result vector type");
+
+ // Interleave the indices
+ int n = sourceType.getNumElements();
+ auto seq = llvm::seq<int64_t>(2 * n);
+ auto indices = llvm::to_vector(llvm::map_range(
+ seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+
+ // Emit a SPIR-V shuffle.
+ rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+ interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
+ rewriter.getI32ArrayAttr(indices));
+
+ llvm::errs() << "vector.interleave to spirv.VectorShuffle succeeded\n";
+
+ return success();
+ }
+};
+
struct VectorLoadOpConverter final
: public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -821,17 +859,15 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorInterleaveOpConvert, VectorSplatPattern,
+ VectorLoadOpConverter, VectorStoreOpConverter>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
PatternBenefit(2));
-
- // Need this until vector.interleave is handled.
- vector::populateVectorInterleaveToShufflePatterns(patterns);
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
|
Member
|
@llvm/pr-subscribers-mlir-spirv Author: Angel Zhang (angelz913) ChangesAdd Full diff: https://github.com/llvm/llvm-project/pull/93240.diff 1 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index c2dd37f481466..b86ebe1a4bb54 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -578,6 +578,44 @@ struct VectorShuffleOpConvert final
}
};
+struct VectorInterleaveOpConvert final
+ : public OpConversionPattern<vector::InterleaveOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::InterleaveOp interleaveOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ // Check the source vector type
+ auto sourceType = interleaveOp.getSourceVectorType();
+ if (sourceType.getRank() != 1 || sourceType.isScalable()) {
+ return rewriter.notifyMatchFailure(interleaveOp,
+ "unsupported source vector type");
+ }
+
+ // Check the result vector type
+ auto oldResultType = interleaveOp.getResultVectorType();
+ Type newResultType = getTypeConverter()->convertType(oldResultType);
+ if (!newResultType)
+ return rewriter.notifyMatchFailure(interleaveOp,
+ "unsupported result vector type");
+
+ // Interleave the indices
+ int n = sourceType.getNumElements();
+ auto seq = llvm::seq<int64_t>(2 * n);
+ auto indices = llvm::to_vector(llvm::map_range(
+ seq, [n](int i) { return (i % 2 ? n : 0) + i / 2; }));
+
+ // Emit a SPIR-V shuffle.
+ rewriter.replaceOpWithNewOp<spirv::VectorShuffleOp>(
+ interleaveOp, newResultType, adaptor.getLhs(), adaptor.getRhs(),
+ rewriter.getI32ArrayAttr(indices));
+
+ llvm::errs() << "vector.interleave to spirv.VectorShuffle succeeded\n";
+
+ return success();
+ }
+};
+
struct VectorLoadOpConverter final
: public OpConversionPattern<vector::LoadOp> {
using OpConversionPattern::OpConversionPattern;
@@ -821,17 +859,15 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
- VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
+ VectorInterleaveOpConvert, VectorSplatPattern,
+ VectorLoadOpConverter, VectorStoreOpConverter>(
typeConverter, patterns.getContext(), PatternBenefit(1));
// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
patterns.add<VectorReductionToFPDotProd>(typeConverter, patterns.getContext(),
PatternBenefit(2));
-
- // Need this until vector.interleave is handled.
- vector::populateVectorInterleaveToShufflePatterns(patterns);
}
void mlir::populateVectorReductionToSPIRVDotProductPatterns(
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
5a54b60 to
cdc9def
Compare
kuhar
requested changes
May 24, 2024
kuhar
reviewed
May 27, 2024
c006ee5 to
cf47613
Compare
kuhar
reviewed
May 27, 2024
24c6d24 to
aed3118
Compare
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
Co-authored-by: Jakub Kuderski <kubakuderski@gmail.com>
aed3118 to
9a688a7
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
vector.interleavetospirv.VectorShuffleconversion,vector.interleavetovector.shuffleconversion frompopulateVectorToSPIRVPatternsand CMake/Bazel dependencies